knn retrieval evaluation #770
base: main
Conversation
@@ -0,0 +1,128 @@
import torch as th
Please add a license header to each Python file.
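For reference, GraphStorm modules open with an Apache-2.0 header roughly along these lines (copy the exact text from an existing module in the repository rather than this sketch):

"""
    Copyright 2023 Contributors

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
"""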
import torch as th
import time
import graphstorm as gs
from graphstorm.utils import is_distributed
It is better to group the graphstorm-related imports together. The usual import order is:
1. system/builtin libraries such as os, time, etc.
2. pip packages
3. local code.
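Applied to this file, that ordering would look roughly like this (using only the imports already shown in the diff):

# 1. Standard/builtin libraries
import time

# 2. pip packages
import torch as th

# 3. Local graphstorm code, grouped together
import graphstorm as gs
from graphstorm.utils import is_distributed, setup_device
from graphstorm.model.utils import load_gsgnn_embeddings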
pred = set(pred)

overlap = len(pred & ground_truth)
#if overlap > 0:
Remove the commented-out code.
from graphstorm.utils import setup_device
from graphstorm.model.utils import load_gsgnn_embeddings

def calculate_recall(pred, ground_truth):
Can you describe how recall is computed in the function docstring?
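For example, a docstring plus the set-based computation could read like this (a sketch matching the pred/ground_truth intersection shown in the diff; exact wording is up to the author):

def calculate_recall(pred, ground_truth):
    """Compute recall of the retrieved items against the ground truth.

    Recall is |pred & ground_truth| / |ground_truth|, i.e., the fraction
    of ground-truth items that appear among the predictions.

    Parameters
    ----------
    pred : iterable
        Predicted (retrieved) item IDs, e.g., the top-K neighbors.
    ground_truth : set
        Ground-truth item IDs.

    Returns
    -------
    float
        Recall in [0, 1].
    """
    pred = set(pred)
    # Assumes ground_truth is non-empty; guard against division by zero if needed.
    overlap = len(pred & ground_truth)
    return overlap / len(ground_truth)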
index_dimension = embs[config.target_ntype].size(1)
# Number of clusters (higher values lead to better recall but slower search)
#nlist = 750
#quantizer = faiss.IndexFlatL2(index_dimension) # Use Flat index for quantization
Remove the commented-out code.
index = faiss.IndexFlatIP(index_dimension)
index.add(embs[config.target_ntype])

#print(scores.abs().mean())
Remove this debug print.
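For context, the flat inner-product index built above is typically queried as follows (a sketch; query_embs and top_k are assumed names, and the FAISS Python API expects float32 numpy arrays, hence the .numpy() conversions):

import faiss

index = faiss.IndexFlatIP(index_dimension)      # exact inner-product search
index.add(embs[config.target_ntype].numpy())    # convert the torch tensor to numpy

# Retrieve the top_k highest inner-product matches per query.
scores, neighbor_ids = index.search(query_embs.numpy(), top_k)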
len_dataloader = max_num_batch = len(test_dataloader)
tensor = th.tensor([len_dataloader], device=device)
if is_distributed():
    th.distributed.all_reduce(tensor, op=th.distributed.ReduceOp.MAX)
Do we need to make it distributed?
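For context on why the MAX reduction appears here: in distributed evaluation, every trainer usually has to step the same number of batches so collective calls stay balanced across ranks. A sketch of how the reduced count is typically consumed (the loop structure is an assumption, not the PR's code):

max_num_batch = tensor.item()  # global max over all trainers
data_iter = iter(test_dataloader)
for _ in range(max_num_batch):
    try:
        batch = next(data_iter)
    except StopIteration:
        # This trainer ran out of local batches but must keep stepping so
        # that collective communication stays in sync across ranks.
        batch = None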
#print(blocks[0].edges(form='uv', etype='also_buy'))
#breakpoint()
# print(dgl.NID)
if 'also_buy' in blocks[0].etypes:
Is this implemented specifically for the Amazon Review dataset?
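If the script is meant to be dataset-agnostic, the edge type could come from the config instead of being hard-coded; a hypothetical sketch (eval_etype as the config field name is an assumption):

# Hypothetical: read the target edge type from the task config instead of
# hard-coding the amazon-review-specific 'also_buy'.
target_etype = config.eval_etype  # assumed config field
if target_etype in blocks[0].etypes:
    src, dst = blocks[0].edges(form='uv', etype=target_etype)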
@@ -1065,6 +1066,32 @@ def save_full_node_embeddings(g, save_embed_path,

    save_shuffled_node_embeddings(shuffled_embs, save_embed_path, save_embed_format)

def load_gsgnn_embeddings(emb_path): |
Can you check if load_pytorch_embedding is useful?
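For comparison, a loader like this usually just walks the per-node-type directories and concatenates the shards in order. A rough sketch under that assumed on-disk layout (verify against what save_full_node_embeddings actually writes, or reuse load_pytorch_embedding if it fits):

import os
import torch as th

def load_gsgnn_embeddings(emb_path):
    """Load saved node embeddings as {ntype: tensor}. Layout is assumed."""
    embs = {}
    for ntype in os.listdir(emb_path):
        ntype_dir = os.path.join(emb_path, ntype)
        if not os.path.isdir(ntype_dir):
            continue
        # Shards must be concatenated in their saved order so that row i
        # still corresponds to (remapped) node ID i.
        shards = sorted(os.listdir(ntype_dir))
        embs[ntype] = th.cat(
            [th.load(os.path.join(ntype_dir, s)) for s in shards], dim=0)
    return embs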
Issue #, if available:
Description of changes:
This is a draft PR because I found that recall@K is extremely low, and it appears the remapping goes wrong when reading the embeddings back.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.